# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from collections import OrderedDict

import onnx
import tensorrt as trt
from onnx import TensorProto, helper

import tensorrt_llm
from tensorrt_llm.builder import Builder
from tensorrt_llm.functional import assertion, shape
from tensorrt_llm.network import net_guard


def trt_dtype_to_onnx(dtype):
    """Translate a TensorRT data type into the equivalent ONNX element type.

    Only float16, float32 and int32 are recognised.

    Args:
        dtype: a TensorRT ``DataType`` (e.g. ``trt.float16``).

    Returns:
        The matching ``TensorProto.DataType`` enum value.

    Raises:
        TypeError: if ``dtype`` is not one of the recognised types.
    """
    known_pairs = (
        (trt.float16, TensorProto.DataType.FLOAT16),
        (trt.float32, TensorProto.DataType.FLOAT),
        (trt.int32, TensorProto.DataType.INT32),
    )
    # Compare with ``==`` in the same order as a chain of if/elif tests would.
    for trt_kind, onnx_kind in known_pairs:
        if dtype == trt_kind:
            return onnx_kind
    raise TypeError("%s is not supported" % dtype)


def _trt_tensor_value_info(tensor):
    """Build an ONNX value-info entry copying a TensorRT tensor's name, dtype and shape."""
    return helper.make_tensor_value_info(tensor.name,
                                         trt_dtype_to_onnx(tensor.dtype),
                                         list(tensor.shape))


def to_onnx(network, path, graph_name='attention'):
    """Dump the layer topology of a TensorRT network to an ONNX file.

    Every TensorRT layer becomes one ONNX node. The op type of each node is
    ``str(layer.type)`` in the custom ``com.nvidia`` domain. No initializers
    (weights) are written, so the file describes the graph structure only.

    Args:
        network: a TensorRT network definition (the script passes
            ``network.trt_network``). It must expose ``num_inputs``,
            ``get_input``, ``num_outputs``, ``get_output``, ``num_layers``
            and ``get_layer``.
        path: destination file path for the serialized ONNX model.
        graph_name: name given to the ONNX graph. Defaults to ``'attention'``,
            the name this function has always used.

    Raises:
        TypeError: if a network input/output dtype is not supported by
            ``trt_dtype_to_onnx``.
    """
    inputs = [
        _trt_tensor_value_info(network.get_input(i))
        for i in range(network.num_inputs)
    ]
    outputs = [
        _trt_tensor_value_info(network.get_output(i))
        for i in range(network.num_outputs)
    ]

    nodes = []
    for i in range(network.num_layers):
        layer = network.get_layer(i)
        # Unset optional layer inputs come back as None and are dropped, so a
        # node's input positions may not match TensorRT's input indices.
        layer_inputs = [
            ipt.name
            for ipt in (layer.get_input(j) for j in range(layer.num_inputs))
            if ipt is not None
        ]
        layer_outputs = [
            layer.get_output(j).name for j in range(layer.num_outputs)
        ]
        nodes.append(
            helper.make_node(str(layer.type),
                             name=layer.name,
                             inputs=layer_inputs,
                             outputs=layer_outputs,
                             domain="com.nvidia"))

    graph = helper.make_graph(nodes,
                              graph_name,
                              inputs,
                              outputs,
                              initializer=None)
    onnx_model = helper.make_model(graph, producer_name='NVIDIA')
    onnx.save(onnx_model, path)


def parse_arguments():
    """Parse command-line options for this GPT-network-to-ONNX export script.

    Returns:
        argparse.Namespace holding the parallelism, dtype, logging, model
        shape, build-limit, plugin and output-directory settings.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    add('--world_size',
        type=int,
        default=1,
        help='world size, only support tensor parallelism now')
    add('--dtype', type=str, default='float32')
    add('--log_level', type=str, default='info')

    # Model hyper-parameters (registration order is kept for --help output).
    for flag, value in (('--vocab_size', 51200), ('--n_layer', 24),
                        ('--n_positions', 1024), ('--n_embd', 1024),
                        ('--n_head', 16)):
        add(flag, type=int, default=value)
    add('--hidden_act', type=str, default='gelu')

    # Build-time size limits.
    for flag, value in (('--max_batch_size', 256), ('--max_input_len', 200),
                        ('--max_output_len', 200)):
        add(flag, type=int, default=value)

    # Optional plugins; each is off unless its flag is given.
    for flag in ('--use_gpt_attention_plugin', '--use_gemm_plugin',
                 '--use_layernorm_plugin'):
        add(flag, default=False, action='store_true')

    add('--output_dir', type=str, default='gpt_outputs')
    return parser.parse_args()


def prepare_inputs(args: argparse.Namespace):
    """Create the network input tensors for the GPT forward pass.

    ``__main__`` calls this inside ``net_guard(network)`` and unpacks the
    result positionally into ``tensorrt_llm_gpt(*inputs)``. The tensors are
    presumably registered as inputs of that active network, in creation
    order (confirm).

    Each ``dim_range`` entry holds two values per dimension. These presumably
    correspond to two optimization profiles, context and generation, judging
    by ``input_len`` being fixed to 1 in the second (TODO confirm).

    NOTE(review): ``args.world_size`` is not consulted here. The KV-cache
    tensors use the full ``args.n_head``. If the model splits heads across
    tensor-parallel ranks, the per-rank head count would differ; confirm
    before using world_size > 1.

    Args:
        args: parsed command-line options from ``parse_arguments``.

    Returns:
        A 6-tuple ``(input_ids, None, past_key_value, sequence_length,
        shape_tensor, True)``:

        - ``past_key_value`` is a list of ``(k, v)`` tensor pairs without the
          GPT attention plugin, or a list of combined ``past_{i}`` tensors
          with it.
        - ``sequence_length`` and ``shape_tensor`` are ``None`` unless the
          plugin is enabled.
        - The meaning of the ``None`` and ``True`` slots is not visible here;
          they are presumably position ids and use_cache. Verify against the
          model's forward signature.
    """
    # Prepare inputs
    # Integer division: an n_embd not divisible by n_head truncates silently.
    head_size = args.n_embd // args.n_head
    max_len = args.max_input_len + args.max_output_len
    # [min, opt, max] triples, presumably TensorRT optimization-profile bounds.
    bs_range = [1, (args.max_batch_size + 1) // 2, args.max_batch_size]
    inlen_range = [1, (args.max_input_len + 1) // 2, args.max_input_len]
    max_len_range = [1, (max_len + 1) // 2, max_len]
    step_range = [1, 1, args.max_input_len + 1]

    input_ids = tensorrt_llm.Tensor(name='input_ids',
                                    dtype=trt.int32,
                                    shape=[-1, -1],
                                    dim_range=OrderedDict([
                                        ('batch_size', [bs_range, bs_range]),
                                        ('input_len', [inlen_range, 1]),
                                    ]))
    # Only 'float16' selects half precision; any other dtype uses float32.
    kv_dtype = trt.float16 if args.dtype == 'float16' else trt.float32
    past_key_value = []
    sequence_length = None
    shape_tensor = None
    if not args.use_gpt_attention_plugin:
        # Without the plugin: separate 4-D K and V cache inputs per layer,
        # shaped [batch, n_head, past_len, head_size].
        for i in range(args.n_layer):
            kv_dim_range = OrderedDict([
                ('batch_size', [bs_range, bs_range]),
                ('num_heads', [args.n_head, args.n_head]),
                ('past_key_len', [0, max_len_range]),
                ('head_size', [head_size, head_size]),
            ])
            k = tensorrt_llm.Tensor(name=f'past_key_{i}',
                                    dtype=kv_dtype,
                                    shape=[-1, args.n_head, -1, head_size],
                                    dim_range=kv_dim_range)
            v = tensorrt_llm.Tensor(name=f'past_value_{i}',
                                    dtype=kv_dtype,
                                    shape=[-1, args.n_head, -1, head_size],
                                    dim_range=kv_dim_range)
            past_key_value.append((k, v))
            # TODO(kaiyu): Remove this when TRT fix the named dimension
            # Assert that input_ids and k share a batch dim, and that k and v
            # have the same cache length (dim 2).
            assertion(shape(input_ids, 0) == shape(k, 0), 'batch size')
            assertion(shape(k, 2) == shape(v, 2), 'kv cache len')
    else:
        # With the plugin: one 5-D cache input per layer whose leading dim
        # is 2 (presumably stacked K and V), plus sequence_length and
        # shape_tensor inputs.
        for i in range(args.n_layer):
            past_key_value.append(
                tensorrt_llm.Tensor(
                    name=f'past_{i}',
                    dtype=kv_dtype,
                    shape=[2, -1, args.n_head, -1, head_size],
                    dim_range=OrderedDict([
                        ('2', [2, 2]), ('batch_size', [bs_range, bs_range]),
                        ('num_heads', [args.n_head, args.n_head]),
                        ('past_key_len', [max_len_range, max_len_range]),
                        ('head_size', [head_size, head_size])
                    ]),
                ))
        sequence_length = tensorrt_llm.Tensor(
            name='sequence_length',
            dtype=trt.int32,
            shape=[-1],
            dim_range=OrderedDict([('batch_size', [bs_range, bs_range])]),
        )

        shape_tensor = tensorrt_llm.Tensor(
            name='shape_tensor',
            dtype=trt.int32,
            shape=[-1, -1],
            dim_range=OrderedDict([('step', [step_range, max_len_range]),
                                   ('cur_seq_len', [0, max_len_range])]))
    return (input_ids, None, past_key_value, sequence_length, shape_tensor,
            True)


if __name__ == '__main__':
    # Build a GPT network with TensorRT-LLM from the command-line options and
    # export its layer graph to <output_dir>/test.onnx. This script does not
    # build an engine; it only dumps the network definition.
    args = parse_arguments()
    tensorrt_llm.logger.set_level(args.log_level)
    tensorrt_llm.set_default_dtype(args.dtype)
    # exist_ok=True replaces the exists()/makedirs() pair, which could race
    # with another process creating the directory in between.
    os.makedirs(args.output_dir, exist_ok=True)

    # Only 'float16' selects half precision; every other --dtype value uses
    # float32 for the model, the KV cache and the marked outputs.
    kv_dtype = trt.float16 if args.dtype == 'float16' else trt.float32

    builder = Builder()

    # Initialize Module
    apply_query_key_layer_scaling = False
    tensorrt_llm_gpt = tensorrt_llm.models.GPTLMHeadModel(
        num_layers=args.n_layer,
        num_heads=args.n_head,
        hidden_size=args.n_embd,
        vocab_size=args.vocab_size,
        hidden_act=args.hidden_act,
        max_position_embeddings=args.n_positions,
        dtype=kv_dtype,
        tensor_parallel=args.world_size,  # TP only
        tensor_parallel_group=list(range(args.world_size)),  # TP only
        apply_query_key_layer_scaling=apply_query_key_layer_scaling)

    # Module -> Network. Plugins are configured on the network before the
    # forward pass below is traced into it.
    network = builder.create_network()
    if args.use_gpt_attention_plugin:
        network.plugin_config.set_gpt_attention_plugin()
    if args.use_gemm_plugin:
        network.plugin_config.set_gemm_plugin()
    if args.use_layernorm_plugin:
        network.plugin_config.set_layernorm_plugin()
    with net_guard(network):
        # Prepare
        network.set_named_parameters(tensorrt_llm_gpt.named_parameters())

        # Forward
        inputs = prepare_inputs(args)
        lm_logits, presents = tensorrt_llm_gpt(*inputs)

        # Mark outputs. The names mirror the cache inputs created by
        # prepare_inputs:
        # - without the attention plugin: present_key_{i}/present_value_{i}
        # - with it: a single present_{i} per layer
        lm_logits.mark_output('logits', kv_dtype)
        if not args.use_gpt_attention_plugin:
            for i, present in enumerate(presents):
                k, v = present
                k.mark_output(f'present_key_{i}', kv_dtype)
                v.mark_output(f'present_value_{i}', kv_dtype)
        else:
            for i, present in enumerate(presents):
                present.mark_output(f'present_{i}', kv_dtype)

    model_path = os.path.join(args.output_dir, 'test.onnx')
    to_onnx(network.trt_network, model_path)
